{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c2eb1ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import sys\n",
    "import numpy as np\n",
    "# ^^^ pyforest auto-imports - don't write above this line\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "try:\n",
    "    import bib_lookup\n",
    "except ModuleNotFoundError:\n",
    "    sys.path.insert(0, \"/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/\")\n",
    "try:\n",
    "    from torch_ecg.utils.misc import MovingAverage, list_sum\n",
    "except ModuleNotFoundError:\n",
    "    sys.path.insert(0, \"/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/\")\n",
    "    from torch_ecg.utils.misc import MovingAverage, list_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "433ce4ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cfg import TrainCfg, TrainCfg_ns, ModelCfg, ModelCfg_ns\n",
    "from model import ECG_CRNN_CINC2021\n",
    "from dataset import CINC2021\n",
    "from trainer import CINC2021Trainer\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d75051e8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97808fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the `__DEBUG__` class attribute to False on the model, trainer and dataset classes.\n",
    "ECG_CRNN_CINC2021.__DEBUG__ = CINC2021Trainer.__DEBUG__ = CINC2021.__DEBUG__ = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28a2ba5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point both training configs at the same local database directory.\n",
    "# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.\n",
    "TrainCfg_ns.db_dir = TrainCfg.db_dir = \"/home/wenhao/Jupyter/wenhao/data/CinC2021/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6fc6d78",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7c11a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training and validation datasets (in that order) from the same\n",
    "# config, both with `lazy=False`.\n",
    "ds_train, ds_val = (\n",
    "    CINC2021(TrainCfg_ns, training=is_training, lazy=False)\n",
    "    for is_training in (True, False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d7e1f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8e45d7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "cf77d6b3",
   "metadata": {},
   "source": [
    "## 12-lead, `resnet_nature_comm_bottle_neck_se`, single linear classifier, AsymmetricLoss, lr = 1e-4 to 2e-3, one-cycle schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cc5fec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training config: CNN backbone only, no RNN and no attention module.\n",
    "train_config = deepcopy(TrainCfg_ns)\n",
    "train_config.cnn_name = \"resnet_nature_comm_bottle_neck_se\"\n",
    "train_config.rnn_name = \"none\"\n",
    "train_config.attn_name = \"none\"\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "train_config.n_leads = len(train_config.leads)\n",
    "\n",
    "# Tranche-specific classes when `tranches_for_training` is set, otherwise the\n",
    "# full class list of the config.\n",
    "tranches = train_config.tranches_for_training\n",
    "if tranches:\n",
    "    classes = train_config.tranche_classes[tranches]\n",
    "else:\n",
    "    classes = train_config.classes\n",
    "\n",
    "# Pick the model config preset that matches the number of leads.\n",
    "if train_config.n_leads == 12:\n",
    "    model_config = deepcopy(ModelCfg_ns.twelve_leads)\n",
    "elif train_config.n_leads == 6:\n",
    "    model_config = deepcopy(ModelCfg_ns.six_leads)\n",
    "elif train_config.n_leads == 4:\n",
    "    model_config = deepcopy(ModelCfg_ns.four_leads)\n",
    "elif train_config.n_leads == 3:\n",
    "    model_config = deepcopy(ModelCfg_ns.three_leads)\n",
    "elif train_config.n_leads == 2:\n",
    "    model_config = deepcopy(ModelCfg_ns.two_leads)\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `model_config` further down.\n",
    "    raise ValueError(f\"no model config preset for {train_config.n_leads} leads\")\n",
    "model_config.cnn.name = train_config.cnn_name\n",
    "model_config.rnn.name = train_config.rnn_name\n",
    "model_config.attn.name = train_config.attn_name\n",
    "\n",
    "# Classifier head: a fresh, empty config of the same container type as\n",
    "# `model_config` (the notebook never imports an `ED` alias explicitly).\n",
    "# NOTE(review): assumes the config container can be constructed without arguments.\n",
    "model_config.clf = type(model_config)()\n",
    "model_config.clf.out_channels = [\n",
    "  # not including the last linear layer, whose out channels equals n_classes\n",
    "]\n",
    "model_config.clf.bias = True\n",
    "model_config.clf.dropouts = 0.0\n",
    "model_config.clf.activation = \"mish\"  # for a single layer `SeqLin`, activation is ignored"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2474dd27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the network from the model config prepared in the previous cell.\n",
    "# NOTE(review): `train_config.classes` is passed here, not the tranche-aware\n",
    "# `classes` computed in the previous cell - confirm this is intended.\n",
    "model = ECG_CRNN_CINC2021(\n",
    "    config=model_config,\n",
    "    n_leads=train_config.n_leads,\n",
    "    classes=train_config.classes,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f765ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the model's `module_size_` attribute.\n",
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04631b32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model in DataParallel when more than one CUDA device is visible,\n",
    "# then move it to the selected device.\n",
    "if 1 < torch.cuda.device_count():\n",
    "    model = DP(model)  # alternative: DDP(model)\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5baf46c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The trainer gets the (possibly wrapped) model together with both configs.\n",
    "trainer = CINC2021Trainer(\n",
    "    model=model, device=device, lazy=True,\n",
    "    model_config=model_config, train_config=train_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56ae29f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the datasets built earlier in the notebook to the trainer.\n",
    "trainer._setup_dataloaders(ds_train, ds_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b9e4423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training; the cell displays whatever `train()` returns.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8a160f4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af90d9c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e485d139",
   "metadata": {},
   "source": [
    "## Collect results and plot training curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f079bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib.pyplot import cm\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "sns.set()\n",
    "\n",
    "# Large fonts: the figure is exported for print.\n",
    "plt.rcParams.update({\n",
    "    'xtick.labelsize': 28,\n",
    "    'ytick.labelsize': 28,\n",
    "    'axes.labelsize': 40,\n",
    "    'legend.fontsize': 24,\n",
    "})\n",
    "\n",
    "# Colours of the active property cycle (set by seaborn above).\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "# The mathtext marker must be a raw string: `\\h` is not a valid escape sequence.\n",
    "markers = [\"p\", \"v\", \"s\", \"d\", \"x\", \"*\", \"+\", r\"$\\heartsuit$\"]\n",
    "marker_size = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59398c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_ecg.utils.misc import MovingAverage, list_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7838f9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoother applied to the loss curves in the plotting cell.\n",
    "ma_ea = MovingAverage()\n",
    "\n",
    "\n",
    "def ma(x):\n",
    "    \"\"\"Identity stand-in for a smoother: return `x` unchanged.\n",
    "\n",
    "    Rebind `ma` to `MovingAverage()` to smooth the challenge-metric curves too.\n",
    "    \"\"\"\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e4476f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training logs of the two runs being compared: multi_scopic and multi_scopic_leadwise.\n",
    "df_ms, df_ms_lw = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in (\n",
    "        \"./results/TorchECG_11-19_22-31_ECG_CRNN_CINC2021_adamw_amsgrad_LR_0.0001_BS_64_multi_scopic.csv\",\n",
    "        \"./results/TorchECG_11-24_00-21_ECG_CRNN_CINC2021_adamw_amsgrad_LR_0.0001_BS_64_multi_scopic_leadwise.csv\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e82be7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train rows: keep only those that carry a challenge metric.\n",
    "df_ms_train = df_ms.loc[df_ms[\"part\"] == \"train\"].dropna(subset=[\"challenge_metric\"])\n",
    "df_ms_lw_train = df_ms_lw.loc[df_ms_lw[\"part\"] == \"train\"].dropna(subset=[\"challenge_metric\"])\n",
    "\n",
    "# Validation rows are taken as they are.\n",
    "df_ms_val = df_ms.loc[df_ms[\"part\"] == \"val\"]\n",
    "df_ms_lw_val = df_ms_lw.loc[df_ms_lw[\"part\"] == \"val\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca40b5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning-rate table with `step` rescaled to the epoch axis of the figure.\n",
    "# `.assign` returns a new frame, so no column is set on a slice of `df_ms`\n",
    "# (which would raise SettingWithCopyWarning).\n",
    "# NOTE(review): 20 and 53 presumably are the logging interval and the number of\n",
    "# logged steps per epoch - confirm against the training setup.\n",
    "df_lr = df_ms[[\"epoch\", \"loss\", \"lr\", \"part\", \"step\"]].assign(\n",
    "    step=lambda frame: (frame[\"step\"] / 20) / 53\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b0c6c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(20, 12))\n",
    "\n",
    "line_width = 4\n",
    "spacing = 2  # stride for the metric and learning-rate points\n",
    "loss_spacing = 16  # stride for the loss points\n",
    "\n",
    "lines = []\n",
    "\n",
    "# Challenge metric on the left axis: (frame, colour/marker index, line style, label).\n",
    "metric_curves = [\n",
    "    (df_ms_train, 0, \"dashed\", \"train\"),\n",
    "    (df_ms_lw_train, 1, \"dashed\", \"lw-train\"),\n",
    "    (df_ms_val, 0, (0, (1, 1)), \"val\"),\n",
    "    (df_ms_lw_val, 1, (0, (1, 1)), \"lw-val\"),\n",
    "]\n",
    "for frame, idx, line_style, label in metric_curves:\n",
    "    lines.append(ax.plot(\n",
    "        frame.epoch.values[::spacing], ma(frame.challenge_metric.values)[::spacing],\n",
    "        marker=markers[idx], markersize=marker_size, linewidth=line_width,\n",
    "        color=colors[idx], ls=line_style, label=label,\n",
    "    ))\n",
    "ax.set_ylim(0.35, 1.05)\n",
    "ax.set_xlabel(\"Epochs (n.u.)\", fontsize=36)\n",
    "ax.set_ylabel(\"Challenge score (n.u.)\", fontsize=36)\n",
    "\n",
    "# Smoothed training loss on a secondary y axis.\n",
    "ax2 = ax.twinx()\n",
    "loss_curves = [\n",
    "    (df_ms, 0, \"train-loss\"),\n",
    "    (df_ms_lw, 1, \"lw-train-loss\"),\n",
    "]\n",
    "for frame, idx, label in loss_curves:\n",
    "    # `.assign` builds a new frame instead of setting a column on a slice,\n",
    "    # which would raise SettingWithCopyWarning.\n",
    "    df_tmp = frame[~frame.loss.isna()].assign(step=lambda d: (d.step / 20) / 53)\n",
    "    lines.append(ax2.plot(\n",
    "        df_tmp.step.values[::loss_spacing], ma_ea(df_tmp.loss.values)[::loss_spacing],\n",
    "        color=colors[idx], ls=\"-\", linewidth=line_width, label=label,\n",
    "    ))\n",
    "ax2.set_ylabel(r\"Loss (n.u.)\")\n",
    "ax2.set_ylim(-0.03, 0.39)\n",
    "ax2.set_yticks(np.arange(0, 0.42, 0.06).tolist())\n",
    "\n",
    "# One shared legend for the curves of both axes, placed below the plot.\n",
    "lns = list_sum(lines)\n",
    "labs = [line.get_label() for line in lns]\n",
    "ax.legend(\n",
    "    lns, labs,\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, -0.1), ncol=3,\n",
    "    fancybox=False, shadow=False, fontsize=30,\n",
    ")\n",
    "\n",
    "# Learning rate, normalised by its maximum and squeezed into the metric axis range.\n",
    "lr_line = ax.plot(\n",
    "    df_lr.step.values[::spacing], (df_lr.lr.values/df_lr.lr.max()/1.8+0.4)[::spacing],\n",
    "    linestyle=\":\", linewidth=6, color=colors[2],\n",
    ")\n",
    "# Both lr annotations come from the logged values so they cannot drift from the data.\n",
    "ax.text(13, 0.97, f\"max lr = {df_lr.lr.max():.3g}\", fontsize=30)\n",
    "ax.text(2, 0.41, f\"start lr = {df_lr.lr.values[0]:.5f}\", fontsize=30)\n",
    "ax2.legend(lr_line, [\"learning rate\",], loc=\"upper right\", fontsize=30, bbox_to_anchor=(0.95, 1));\n",
    "\n",
    "plt.savefig(\"./results/cinc2021_nn_compare.svg\", dpi=1200, bbox_inches=\"tight\", transparent=False);\n",
    "plt.savefig(\"./results/cinc2021_nn_compare.pdf\", dpi=1200, bbox_inches=\"tight\", transparent=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ba65d6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3bd9452",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c0a22c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
